library(tidyverse)
library(lubridate)

# Lookup tables used by get_component_language_scores().
# GEOID is read as character (instead of letting read_csv() guess a numeric
# type) so that leading zeros survive and it can be joined against the
# character GEOID column that get_component_language_scores() expects.
language_preference_by_cbg <- read_csv(
  "../data/language_preference_by_cbg.csv",
  col_types = cols(GEOID = col_character())
)
acs_surnames <- read_rds("../data/surnames.rds")
acs_first_names <- read_rds("../data/first_names.rds")

# Expects a dataframe with `birthdate` (Date), `first_name` (char),
# `last_name` (char), and `GEOID` (char) columns. If an `Opened.Date` column
# is present it is converted with as.Date(format = "%m/%d/%Y") and age is
# measured at that date; otherwise age is measured at today's date.
# Returns the dataframe with `age`, `spanish_first_name_score`,
# `spanish_surname_score`, `spanish_address_score` columns. `first_name` and
# `last_name` are returned in their normalized (uppercased, cleaned) form.
# `spanish_address_score` is NA for GEOIDs absent from the lookup table.
# If a required column is missing, a warning is issued and `df` is returned
# unchanged.
# Uses the lookup tables `acs_surnames`, `acs_first_names`, and
# `language_preference_by_cbg` loaded at the top of this file.
get_component_language_scores <- function(df) {
  required_cols <- c("birthdate", "first_name", "last_name", "GEOID")
  missing_cols <- setdiff(required_cols, names(df))
  if (length(missing_cols) > 0) {
    warning(
      "Missing one of: birthdate, first_name, last_name, GEOID columns ",
      "(missing: ", paste(missing_cols, collapse = ", "), ")",
      call. = FALSE
    )
    return(df)
  }

  # Drop score columns left over from an earlier call. Otherwise the joins
  # below would either key on the stale column or duplicate it with .x/.y
  # suffixes.
  df <- df %>%
    select(
      -any_of(c(
        "spanish_first_name_score",
        "spanish_surname_score",
        "spanish_address_score"
      ))
    )

  # Age in years (lubridate durations, i.e. 365.25-day years).
  if ("Opened.Date" %in% colnames(df)) {
    df <- df %>%
      mutate(
        Opened.Date = as.Date(Opened.Date, format = "%m/%d/%Y"),
        age = time_length(Opened.Date - birthdate, "years")
      )
  } else {
    df <- df %>%
      mutate(age = time_length(Sys.Date() - birthdate, "years"))
  }

  # Uppercase names. Make last names alphanumeric only. Trim first names, then
  # truncate them after the first space to avoid middle names/initials
  # (without the trim, a leading space would blank the whole name).
  df <- df %>%
    mutate(
      last_name = str_replace_all(toupper(last_name), "[^[:alnum:]]", ""),
      first_name = str_replace(str_trim(toupper(first_name)), "(?s) .*", "")
    ) %>%
    left_join(
      acs_surnames %>%
        select(last_name = NAME, spanish_surname_score = PCTHISPANIC),
      by = "last_name"
    ) %>%
    mutate(
      # PCTHISPANIC is a percentage; convert to a proportion.
      # NOTE(review): .08226 is the hard-coded fallback for surnames missing
      # from the table; its source is not documented here.
      spanish_surname_score = ifelse(
        is.na(spanish_surname_score),
        .08226,
        spanish_surname_score / 100
      )
    ) %>%
    left_join(
      acs_first_names %>%
        select(first_name = NAME, spanish_first_name_score = PCTHISPANIC),
      by = "first_name"
    ) %>%
    mutate(
      # Unmatched first names take the value in the table's last row,
      # presumably a catch-all "other names" entry -- confirm.
      spanish_first_name_score = ifelse(
        is.na(spanish_first_name_score),
        tail(acs_first_names, 1)$PCTHISPANIC / 100,
        spanish_first_name_score / 100
      )
    )

  # Match by census block group. Both sides must store GEOID as character,
  # or the join fails on incompatible types.
  df %>%
    left_join(
      language_preference_by_cbg %>%
        select(GEOID, spanish_address_score = spanish_probability),
      by = "GEOID"
    )
}

# Cutoffs algorithm. Expects a threshold risk score and threshold bin count
# which are used to determine which individuals are flagged as likely Spanish
# speakers, and a dataframe of bins along age, address score, first name score,
# last name score (the training data the cutoffs are calibrated on).
# `binned_train_df` must contain `first_name_score_bin`, `last_name_score_bin`,
# `age_bin`, `addr_score_bin`, and `is_spanish` (averaged per cell).
# Setting `plot = TRUE` will print a plot of the cutoffs on the distribution
# represented by `binned_train_df`.
#
# Returns a data.frame with one row per combination of observed bin values and
# the columns `first_name_score_bin`, `last_name_score_bin`, `age_bin`,
# `addr_score_bin`, `perc_spanish`, `bin_count`, and `cutoff` ("In"/"Out").
# Cells with no training rows get `bin_count = NA` and a `perc_spanish`
# imputed as the mean of the observed cells in the same name-bin subgroup on
# the same side of the cutoff (NaN if that side has no observed cells).
#
# Boundary rule, applied separately to each (first name bin, last name bin)
# subgroup. A cell qualifies when perc_spanish >= threshold_score and
# bin_count >= threshold_bin_count. Address bins are scanned in the order
# returned by unique(); in each address row, every age bin at or after the
# earliest qualifying age bin seen in that row or any earlier row is "In".
# The "In" region is therefore a staircase that only widens from one address
# row to the next.
# NOTE(review): bin order comes from unique() on the summarised data (group-key
# order); confirm this is the intended scan direction.
set_cutoffs <- function(
  threshold_score,
  threshold_bin_count,
  binned_train_df,
  plot = FALSE
) {
  bin_cols <- c(
    "first_name_score_bin",
    "last_name_score_bin",
    "age_bin",
    "addr_score_bin"
  )

  # Share of Spanish speakers and row count in every observed cell.
  observed <- binned_train_df %>%
    group_by(
      first_name_score_bin,
      last_name_score_bin,
      age_bin,
      addr_score_bin
    ) %>%
    summarize(
      perc_spanish = mean(is_spanish),
      bin_count = n(),
      .groups = "drop"
    )

  age_bins <- unique(observed$age_bin)
  addr_score_bins <- unique(observed$addr_score_bin)
  first_name_score_bins <- unique(observed$first_name_score_bin)
  last_name_score_bins <- unique(observed$last_name_score_bin)

  # Complete grid of cells; unobserved cells carry the sentinel -1.
  data <- expand.grid(
    first_name_score_bin = first_name_score_bins,
    last_name_score_bin = last_name_score_bins,
    age_bin = age_bins,
    addr_score_bin = addr_score_bins
  ) %>%
    left_join(observed, by = bin_cols) %>%
    mutate(
      perc_spanish = ifelse(is.na(perc_spanish), -1, perc_spanish),
      bin_count = ifelse(is.na(bin_count), -1, bin_count)
    )

  # Positions of each cell's age/address bins in the scan order.
  n_age_bins <- length(age_bins)
  age_idx <- match(data$age_bin, age_bins)
  addr_idx <- match(data$addr_score_bin, addr_score_bins)
  qualifies <- data$perc_spanish >= threshold_score &
    data$bin_count >= threshold_bin_count

  data$cutoff <- rep("Out", nrow(data))

  for (i in seq_along(first_name_score_bins)) {
    for (j in seq_along(last_name_score_bins)) {
      rows <- which(
        data$first_name_score_bin == first_name_score_bins[i] &
          data$last_name_score_bin == last_name_score_bins[j]
      )
      grp_age <- age_idx[rows]
      grp_addr <- addr_idx[rows]
      grp_qualifies <- qualifies[rows]

      # Earliest qualifying age index in each address row
      # (n_age_bins + 1 when the row has none).
      first_hit <- vapply(
        seq_along(addr_score_bins),
        function(addr) {
          hits <- grp_age[which(grp_addr == addr & grp_qualifies)]
          if (length(hits) > 0) min(hits) else n_age_bins + 1
        },
        numeric(1)
      )
      # Once an age column qualifies it stays "In" for every later address
      # row, so each row's boundary is the running minimum of earliest hits.
      boundary <- cummin(first_hit)
      is_in <- grp_age >= boundary[grp_addr]
      data$cutoff[rows[is_in]] <- "In"

      # Impute empty cells (-1) from the observed cells on the same side of
      # the boundary within this subgroup.
      for (side in list(rows[is_in], rows[!is_in])) {
        side_perc <- data$perc_spanish[side]
        empty <- side_perc == -1
        if (any(empty)) {
          data$perc_spanish[side[empty]] <- mean(side_perc[side_perc >= 0])
        }
      }
    }
  }

  # Empty cells have no meaningful count.
  data <- data %>%
    mutate(
      bin_count = ifelse(bin_count == -1, NA, bin_count)
    )

  if (plot) {
    cutoff_plot <- data %>%
      ggplot(
        aes(x = age_bin,
            y = addr_score_bin,
            fill = perc_spanish * 100,
            size = bin_count,
            colour = cutoff)
      ) +
      geom_tile(size = 1, alpha = 0) +
      geom_point(shape = 21, color = "black") +
      xlab("Age bin") +
      ylab("Address score bin") +
      labs(fill = "% Spanish") +
      scale_fill_distiller(palette = "RdYlBu", direction = 1) +
      scale_x_discrete(guide = guide_axis(angle = 45)) +
      theme(
        axis.text = element_text(size = 14),
        axis.title = element_text(face = "italic"),
        legend.position = "bottom",
        legend.text = element_text(size = 16),
        panel.border = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(color = "black"),
        strip.text.x = element_text(size = 12),
        strip.text.y = element_text(size = 14)
      ) +
      ggtitle(
        paste0(
          "Distribution of Spanish speakers with cutoff ",
          round(threshold_score, 2)
        )
      ) +
      scale_size(range = c(.1, 30)) +
      scale_alpha_discrete(range = c(1, 0.15)) +
      facet_grid(
        fct_rev(last_name_score_bin) ~ first_name_score_bin,
        labeller = label_both
      )
    print(cutoff_plot)
  }

  data
}